#########################################################################################
### Implements Panel Match (see https://github.com/insongkim/PanelMatch) in Python, 
# with some modifications for the treatment status. 
#########################################################################################

from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters()
import numpy as np
import pandas as pd
from scipy import stats
import matplotlib as mpl
mpl.use('tkagg')
import matplotlib.pyplot as plt 
from scipy.spatial import distance
import statsmodels.formula.api as smf
import datetime as dt
from datetime import datetime
import matplotlib.dates as mdates

# --- File locations: replace the placeholders before running -----------------
# CSV holding the input panel.
data_path = "<insert data path>.csv"
# PNG file the treatment-status plot is written to.
treatment_plot_out_path = "<insert output filepath>.png"
# Excel file the matched-pairs dataset is written to.
matched_pairs_out_path = "<insert output filepath>.xlsx"

# Time id at which advertisement labels begin (Jun 4, 2020, counted in days
# from the start of the data on Jan 1, 2018).
# NOTE(review): display_treatment maps timeid 0 to 2018-01-01, under which
# timeid 886 is 2020-06-05 -- confirm whether timeid is 0- or 1-based.
LABEL_DAY = 886

# Strings parsed as missing values when reading the CSV. This is a custom list
# because the pandas defaults would also treat "NA" -- the ISO code for
# Namibia -- as null.
NANS = [
    '-1.#IND',
    '1.#QNAN',
    '1.#IND',
    '-1.#QNAN',
    '#N/A N/A',
    '#N/A',
    'N/A',
    'n/a',
    '<NA>',
    '#NA',
    'NULL',
    'null',
    'NaN',
    '-NaN',
    'nan',
    '-nan',
    '',
]

def pd_open(path):
    """
    Read the csv file at `path` into a Pandas dataframe.

    Pandas' built-in missing-value strings are switched off and the module-level
    NANS list is used in their place, so the literal string "NA" is kept as a
    valid entry rather than being parsed as null. No column is used as the index.
    """
    read_options = dict(keep_default_na=False, na_values=NANS, index_col=None)
    frame = pd.read_csv(path, **read_options)
    return frame

def transpose_treatment_effects(treatment_effects):
    """
    Unfurl the treatment effects dataset so that each row is a unique time id and
    each remaining column is one of the original rows (one per lead period).

    treatment_effects: dataframe whose column labels look like '<unit id>,<time id>'.
    The text after the first comma is parsed as an integer time id.

    Returns: dataframe with a 'timeid' column followed by one column per row of
    the input, holding the mean across all units that share that time id.
    """
    transposed = treatment_effects.T
    #parse the time id out of each label, then drop the labels themselves: a string
    #column cannot be averaged, and newer pandas raises instead of silently skipping it
    timeids = [int(str(label).split(',')[1]) for label in transposed.index]
    transposed = transposed.reset_index(drop=True)
    transposed['timeid'] = timeids
    return transposed.groupby('timeid', as_index=False).mean()
    
def display_treatment(df):
    """
    Make a graph to display treatment status, following the technique in PanelMatch,
    and save it to the module-level `treatment_plot_out_path`.

    X axis is time, Y axis is the units. Red squares correspond to treated periods and
    blue squares correspond to untreated periods.

    df: must contain the columns 'iso', 'id', 'timeid' and 'treatment'. The treatment
    column must be binary (True when treated). 'timeid' is interpreted as a number of
    days after Jan 1, 2018. Data must be sorted by time. The caller's dataframe is
    not modified.
    """
    fig, ax = plt.subplots()
    #work on a copy so that adding columns never writes into (a view of) the caller's data
    df = df[['iso', 'id', 'timeid', 'treatment']].copy()
    #x value is the calendar date derived from timeid; y value is the unit's rank
    df['date'] = pd.Timestamp('2018-01-01') + pd.to_timedelta(df['timeid'], unit='D')

    #rank the units by number of treated periods so that the most-treated units are at the top
    ranks = df.groupby('id')['treatment'].sum().rank(method='first', ascending=True)
    ranks = ranks.rename('rank').reset_index()
    df = df.merge(ranks, on='id', how='left', validate='m:1')

    #edit the size and shape of dots so they're easier to see
    colors = ['tab:red' if t else 'tab:blue' for t in df['treatment']]
    ax.scatter(df['date'], df['rank'], color=colors, marker='s', linewidths=1, s=10)

    #format date axis
    ax.xaxis.set_major_locator(mdates.MonthLocator(interval=6)) #one tick every 6th month
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b \'%y')) #format like "Jun '18"
    fig.autofmt_xdate() #display dates at an angle instead of straight down

    #label the y-axis ticks with the iso code instead of the numerical rank
    id_to_iso = df.drop_duplicates(subset=['rank', 'iso'])
    plt.yticks(id_to_iso['rank'], id_to_iso['iso'], fontsize=8)
    ax.set_ylabel("Country (ISO code)")
    ax.set_xlabel("Date")

    #change figure sizing and spacing
    fig.set_size_inches(12, 16)
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    #the output location is the module-level constant; there is no `out_path` in this file
    plt.savefig(treatment_plot_out_path, bbox_inches='tight', dpi=200)

def set_treatment(df, x, lag, lead, high=None, low=None):
    """
    Create a new column 'treatment' for a single unit based on the column x.
    The unit is treated when the value of x rises from below a "low" threshold
    to above a "high" threshold within "lag" periods. Treatment status for the first lag
    periods should be ignored.

    Additionally, creates a column 'control' to indicate if a period can serve
    as a control unit. If a row has control==True, x was "low" in that period and in
    each of the past lag periods. The column 'delta_treatment' indicates when a unit
    has changed treatment status: it is treated now and was untreated in each of the
    past lag periods.

    df: dataframe to use. It is modified in place (three columns are added or
    overwritten) and also returned; pass a copy to keep the original untouched.

    x: variable to use to set treatment status

    lag: number of past periods whose treatment or covariate values may influence the present value

    lead: number of future periods for which we consider treatment effects (not used here)

    high, low: thresholds on x. If either one is missing, BOTH are recomputed from the
    data: take the values of x strictly below its 90th percentile, and use 90% and 10%
    of their maximum respectively.

    All look-backs use positional shifts over the rows of df, so rows must be in time
    order. NOTE(review): if df stacks several units, the first rows of each unit look
    back into the previous unit's rows -- confirm that this is intended by the callers.
    """
    if high is None or low is None:
        lower_90th = np.percentile(df[x], 90)
        new = df[df[x] < lower_90th][x]
        low = 0.1 * new.max()
        high = 0.9 * new.max()

    values = df[x]

    #if x is high, and x was low at any point in the past `lag` periods, mark it as treated
    was_low = [values.shift(i) < low for i in range(1, lag+1)]
    treatment = (values > high) & np.any(was_low, axis=0)
    #then, if x is either high or mid, and the preceding period is treated, it's still treated.
    #fill_value=False keeps the shifted series boolean instead of object-typed with NaN.
    #NOTE(review): this is one vectorised pass, so treatment is carried forward by a single
    #period only -- confirm whether it was meant to persist for as long as x stays above low.
    treatment |= (values > low) & treatment.shift(1, fill_value=False)

    #delta treatment: untreated for each of the last `lag` periods, and treated this period.
    #The shifted series must be genuinely boolean here: `~` on an object series of Python
    #bools yields -1/-2, which are both truthy and would make every treated row a "change".
    untreated_before = [~treatment.shift(i, fill_value=False) for i in range(1, lag+1)]
    delta_treatment = treatment & np.all(untreated_before, axis=0)
    #control: x is low now and in each of the past `lag` periods (missing history counts as not low)
    control = np.all([values.shift(i) < low for i in range(lag+1)], axis=0)

    #write the results onto the caller's dataframe
    df['treatment'] = treatment
    df['control'] = control
    df['delta_treatment'] = delta_treatment

    return df

def refine_matches(lag_treated, lag_untreated, covars, metric='maha', match_type='match', max_matches=5):
    """
    Refine the matched set based on the covariates. Returns a dataframe with one row per
    unit in lag_untreated, in the order the units first appear there, with the columns
    'id', 'weight', 'dist' and (for match_type "match") 'rank'.

    lag_treated: A dataframe containing a treated unit ending at the time it's first treated,
    and starting L periods before that. Must have exactly one row per time id.

    lag_untreated: Same as above, but it contains all the untreated units over the same time period.
    This is the initial, unrefined matched set. Every time id must list the same units in the
    same order (e.g. a balanced panel sorted by id, then time id).

    covars: list of covariates to match on. These should be numeric columns in the above two dataframes.

    metric: type of matching to do. Only "maha" for Mahalanobis distance matching is supported;
    with any other value all distances stay at 0.

    match_type: "match" gives every unit in the refined matched set an equal weight. "weight"
    gives the selected units weights proportional to 1/dist.

    max_matches: Maximum number of units to include in the refined matched set.

    For "match", the weights of the selected units always sum to 1, also when fewer than
    max_matches candidate units are available.
    """
    out = pd.DataFrame({'id': lag_untreated['id'].unique()})
    #float columns from the start: both are filled with non-integer values below
    out['weight'] = 0.0
    out['dist'] = 0.0

    #With a single candidate the covariance of the candidates is undefined; that
    #candidate is the only possible match, so its distance is simply left at 0.
    if metric == 'maha' and len(out) > 1:
        dist_by_time = {} #time id -> distance of each candidate unit at that time
        for t, grp in lag_untreated.groupby('timeid'):
            #mahalanobis needs plain 1-D float vectors, not a one-row dataframe
            treated_vec = lag_treated.loc[lag_treated['timeid'] == t, covars].to_numpy(dtype=float).ravel()
            control_mat = grp[covars].to_numpy(dtype=float)
            #inverse covariance of the candidates' covariates at this time;
            #atleast_2d covers a single covariate, where np.cov returns a scalar
            inv_cov = np.linalg.pinv(np.atleast_2d(np.cov(control_mat.T)))
            dist_by_time[t] = [distance.mahalanobis(treated_vec, row, inv_cov) for row in control_mat]
        #rows are units and columns are times; average each unit's distance over time
        out['dist'] = pd.DataFrame(dist_by_time).mean(axis=1)

    #ties are broken by order of appearance
    ranks = out['dist'].rank(method='first')
    selected = ranks <= max_matches

    if match_type == 'match':
        #the rank is only needed to produce the dataset containing each treated unit and its matched set
        out['rank'] = ranks
        n_selected = int(selected.sum())
        if n_selected:
            #divide by the number actually selected so the weights form an average
            out.loc[selected, 'weight'] = 1 / n_selected
    elif match_type == 'weight':
        #caliper constraint only assigns weights to the top matches
        out.loc[selected, 'weight'] = 1 / out.loc[selected, 'dist']
        out['weight'] = out['weight'] / out['weight'].sum()

    return out

def estimate_att(df, x, y, lags, leads, covars):
    """
    Run algorithm for Panel Match. Prints out the average treatment effects among the treated (ATT),
    and returns the full dataframe of treatment effects.

    Returns: dataframe of treatment effects. Rows are the lead periods (row 0 is the period
    of treatment, the last row is `leads` periods later). Each column is one unit at one time;
    the column name is '{A},{B}' where A is the unit ID and B is the time ID.

    df: Must be sorted by id, timeid columns. No periods may be missing. Must have columns named
    'delta_treatment' and 'control'

    x: not used

    y: outcome variable

    lags: number of past periods whose treatment or covariate values may influence the present value

    leads: number of future periods for which we consider treatment effects

    covars: covariates to use for matching. must be a list.

    Iterates through times, and for each unit that changes treatment status it finds the matched set,
    refines the matched set, computes the outcome of the matched set, and computes the difference-in-
    differences estimator. Times at which no unit qualifies as a control are skipped.
    """
    effects_by_event = {} #column label -> one treatment effect per lead period
    last_usable_t = df['timeid'].max() - leads
    for t in df['timeid'].unique():
        if t % 10 == 0:
            print(t) #keep track of our progress as we go; it's slow.
        #last F and first L time periods should be ignored, treatment status or outcomes are undefined.
        if (t > last_usable_t) or (t < lags):
            continue

        #get the cross-section at this moment in time
        cross_sec = df[df['timeid'] == t]
        treated_cross_sec = cross_sec[cross_sec['delta_treatment']]
        if len(treated_cross_sec) == 0:
            continue

        #Initial, unrefined matched set is where control is True at this time.
        match_ids = cross_sec[cross_sec['control']]['id']
        if len(match_ids) == 0:
            continue #nothing to compare against, so no effect can be computed at this time

        #Get indices corresponding to the dataframe at these lags and leads.
        lag_idxs = (df['timeid'] <= t) & (df['timeid'] >= t-lags)
        lead_idxs = (df['timeid'] >= t) & (df['timeid'] <= t+leads)
        in_matched_set = df['id'].isin(match_ids)

        #Lags, leads and first lag of the unrefined matched set; the same for every treated unit at t.
        lag_untreated = df[lag_idxs & in_matched_set]
        lead_untreated = df[lead_idxs & in_matched_set]
        first_lag_untreated_rows = df[(df['timeid'] == t-lags) & in_matched_set]

        #Each unit_id is a unit at the moment it switches to becoming treated
        for unit_id in treated_cross_sec['id']:
            is_unit = df['id'] == unit_id
            #Get the set of lags and leads for the treated unit
            lag_treated = df[lag_idxs & is_unit]
            lead_y_treated = df[lead_idxs & is_unit][y] #only care about the outcome for leads

            #Compute the weights from the lagged sets; merge onto the set of leads and compute the weighted outcome.
            weights = refine_matches(lag_treated, lag_untreated, covars)[['id', 'weight']]
            lead_y_untreated = lead_untreated.merge(weights, on='id', how='left', validate='m:1')
            lead_y_untreated['outcome'] = lead_y_untreated[y] * lead_y_untreated['weight']
            #weighted sum: aggregate across units. Only the outcome column is summed, so
            #text columns elsewhere in df cannot interfere.
            lead_outcome_untreated = lead_y_untreated.groupby('timeid')['outcome'].sum()

            #Get the pre-treatment treated unit outcome at time t-lags
            first_lag_treated = df[(df['timeid'] == t-lags) & is_unit][y]

            #Get the pre-treatment, untreated unit outcome, also using the weighted sum
            first_lag_untreated = first_lag_untreated_rows.merge(weights, on='id', how='left', validate='1:1')
            first_lag_untreated['outcome'] = first_lag_untreated[y] * first_lag_untreated['weight']
            first_outcome_untreated = first_lag_untreated.groupby('timeid')['outcome'].sum()

            #Treatment effect is the difference between the changes over time for treated and untreated groups
            treated_delta = np.array(lead_y_treated) - np.array(first_lag_treated)
            untreated_delta = np.array(lead_outcome_untreated) - np.array(first_outcome_untreated)
            dif = np.subtract(treated_delta, untreated_delta)

            effects_by_event['{},{}'.format(unit_id, t)] = dif

    #assemble all columns at once rather than inserting them one by one
    treatment_effects = pd.DataFrame(effects_by_event)

    if treatment_effects.shape[1] >= 2:
        for i, row in treatment_effects.iterrows():
            print(stats.bayes_mvs(row, alpha=0.95)[0])
    elif treatment_effects.shape[1] == 1:
        #bayes_mvs needs at least two observations per row
        print("Only one treatment event found; no interval estimate of the ATT is available.")
    return treatment_effects

def evaluate_labels(df, x, y, lags, leads, covars):
    """
    Evaluate the treatment effects associated with Impressions before and after
    the change in Facebook's policy on state-funding labels.

    Treatment thresholds are derived from column x. They are applied separately to the
    columns 'impressions_with_labels' and 'impressions_without_labels', which df must
    contain. For each lead period the function prints a difference-in-differences:
    (labeled after - labeled before) - (unlabeled after - unlabeled before), where
    before/after refer to time ids below/above LABEL_DAY.

    df: panel dataframe; it is not modified (set_treatment is run on copies).

    x: variable used to find the treatment thresholds

    y: outcome variable

    lags: number of past periods whose treatment or covariate values may influence the present value

    leads: number of future periods for which we consider treatment effects

    covars: covariates to use for matching. must be a list.
    """
    #First, manually find the thresholds for treatments from the original impressions column
    lower_90th = np.percentile(df[x], 90)
    new = df[df[x] < lower_90th][x]
    low = 0.1 * new.max()
    high = 0.9 * new.max()

    #apply these thresholds when setting treatment; copy the df because set_treatment modifies in-place
    df_label = set_treatment(df.copy(), 'impressions_with_labels', lags, leads, high, low)
    df_unlabel = set_treatment(df.copy(), 'impressions_without_labels', lags, leads, high, low)

    #get treatment effects for each
    treatment_effects_unlabel = estimate_att(df_unlabel, x, y, lags, leads, covars)
    treatment_effects_label = estimate_att(df_label, x, y, lags, leads, covars)

    #process treatment effects
    unlabeled_graph = transpose_treatment_effects(treatment_effects_unlabel)
    labeled_graph = transpose_treatment_effects(treatment_effects_label)
    #NOTE(review): events exactly at timeid == LABEL_DAY fall in neither group -- confirm intended.
    treated_pre = labeled_graph[labeled_graph['timeid']<LABEL_DAY]
    treated_post = labeled_graph[labeled_graph['timeid']>LABEL_DAY]
    untreated_pre = unlabeled_graph[unlabeled_graph['timeid']<LABEL_DAY]
    untreated_post = unlabeled_graph[unlabeled_graph['timeid']>LABEL_DAY]

    #Print results: one line per lead period, i.e. the period of treatment plus `leads` periods after it
    for i in range(leads + 1):
        t1 = treated_pre[i].mean()
        t2 = treated_post[i].mean()
        u1 = untreated_pre[i].mean()
        u2 = untreated_post[i].mean()
        fx = (t2 - t1) - (u2 - u1)
        print("{}. Treatment effect: {}".format(i, fx))

def matched_sets_dataset(df, x, y, lags, leads, covars):
    """
    Create a dataset that allows side-by-side comparisons between the treated units and
    their matched set, and write it to the Excel file at `matched_pairs_out_path`.
    Each row contains a treated unit, the date that unit was treated, one control unit
    in the matched set, the mean Mahalanobis distance between the matching and treated units' covariates,
    the rank of the control unit within the matched set (ranked in ascending order by Mahalanobis
    distance), and mean covariate values of the treated unit and the control unit. Each row represents
    the mean of lags+1 periods (the date of treatment and the preceding `lags` periods).

    df: Must be sorted by id, timeid columns. No periods may be missing. Must have columns named
    'delta_treatment' and 'control', as well as 'iso', 'country', 'events' and 'date'.

    x: not used

    y: not used

    lags: number of past periods whose treatment or covariate values may influence the present value

    leads: number of future periods for which we consider treatment effects

    covars: covariates to use for matching. must be a list.

    Follows a similar procedure to estimate_att but compiles a dataframe instead of actually estimating
    the treatment effect. Returns nothing; if no treatment events are found an empty sheet is written.
    """
    #for each point K with delta_treatment true, collect the set of unit-times with 'control' true at that time
    pair_frames = [] #one dataframe per treatment event, concatenated at the end
    unit_keys = ['id', 'iso', 'country']
    last_usable_t = df['timeid'].max() - leads
    for t in df['timeid'].unique():
        if t % 10 == 0:
            print(t) #keep track of our progress as we go; it's slow.
        #last F and first L time periods should be ignored, treatment status or outcomes are undefined.
        if (t > last_usable_t) or (t < lags):
            continue

        #get the cross-section at this moment in time
        cross_sec = df[df['timeid'] == t]
        treated_cross_sec = cross_sec[cross_sec['delta_treatment']]
        if len(treated_cross_sec) == 0:
            continue

        #Initial, unrefined matched set is where control is True at this time.
        match_ids = cross_sec[cross_sec['control']]['id']
        if len(match_ids) == 0:
            continue #no control units, so there are no pairs to report for this time

        #Get indices corresponding to the dataframe at these lags.
        lag_idxs = (df['timeid'] <= t) & (df['timeid'] >= t-lags)

        #Get the lags of the unrefined matched set.
        lag_untreated = df[lag_idxs & df['id'].isin(match_ids)]

        #Each unit_id is a unit at the moment it switches to becoming treated
        for unit_id in treated_cross_sec['id']:
            #Get the set of lags for the treated unit
            lag_treated = df[lag_idxs & (df['id'] == unit_id)]

            #Compute the weights from the lagged sets; merge onto the set of lags
            weights = refine_matches(lag_treated, lag_untreated, covars)
            lag_untreated_weights = lag_untreated.merge(weights, on='id', how='left', validate='m:1')

            #compute the mean (over time) values of the covariates for each country
            #weight, distance, rank are the same for each country, so are unaffected by .mean()
            #only these numeric columns are averaged; text columns such as the date cannot be
            match_cols = covars + ['weight', 'dist', 'rank']
            rows_untreated = lag_untreated_weights.groupby(unit_keys, as_index=False)[match_cols].mean()
            rows_untreated = rows_untreated[['country'] + match_cols]
            rows_untreated = rows_untreated.sort_values(by='rank') #put in order of rank

            #indicate whether each untreated unit has positive events
            #it does if the "artificial control unit" (weighted sum of the refined matched set) has positive events
            events_weighted = lag_untreated_weights['weight'] * lag_untreated_weights['events']
            matched_events = events_weighted.groupby(lag_untreated_weights['timeid']).sum()
            rows_untreated['has_positive_events_matched_set'] = 0
            #only rows within the refined match set (those given a positive weight) are indicated
            in_refined_set = rows_untreated['weight'] > 0
            rows_untreated.loc[in_refined_set, 'has_positive_events_matched_set'] = int((matched_events > 0).any())

            #repeat the same for the treated unit
            row_treated = lag_treated.groupby(unit_keys, as_index=False)[covars].mean()[['country'] + covars]
            row_treated['has_positive_events_treated'] = int((lag_treated['events'] > 0).any())

            #pair the treated row with every untreated row by merging on a temporary constant column
            row_treated['tmp'] = 1
            rows_untreated['tmp'] = 1
            new_rows = row_treated.merge(rows_untreated, on='tmp', suffixes=['_treated', '_untreated'], validate='1:m').drop(columns=['tmp'])

            #the treatment date is the last date within the lags
            new_rows['date_treated'] = lag_treated['date'].max()

            pair_frames.append(new_rows)

    if pair_frames:
        treatment_pairs = pd.concat(pair_frames, ignore_index=True)
        #move the treatment date to the front
        cols = ['date_treated'] + [c for c in treatment_pairs.columns if c != 'date_treated']
        treatment_pairs = treatment_pairs[cols]
    else:
        treatment_pairs = pd.DataFrame()
    treatment_pairs.to_excel(matched_pairs_out_path)


#Number of days for lag and leads
#NOTE: several functions above also read `lag` and `lead` as module-level globals,
#so these two names must be defined before any of them is called.
lag = 10
lead = 7
#list of columns in the dataframe to use as covariates, for matching
covars = ['events', 'arm_imports', 'exports', 'imports', 'fdi', 'diplomatic']
x = 'impressions' #variable used to set treatment status
y = 'gdelt_sentiment' #outcome variable

#Load the panel from data_path; the string "NA" is kept as a value (see pd_open)
df = pd_open(data_path)


#Set treatment statuses, and then make a treatment plot, estimate ATT, or make matched sets dataset
#NOTE(review): set_treatment describes itself as operating on a single unit, but it is
#given the whole panel here, so its shifted look-backs run across unit boundaries --
#confirm that this is intended.
df = set_treatment(df, x, lag=lag, lead=lead)
display_treatment(df)
estimate_att(df, x, y, lag, lead, covars)
matched_sets_dataset(df, x, y, lag, lead, covars)

#Evaluate labels - do NOT set treatment before running this function:
#NOTE(review): as written, set_treatment has already been applied to df a few lines above.
#evaluate_labels re-runs set_treatment on copies of df, overwriting the three treatment
#columns there -- confirm whether the steps above are meant to be skipped for this analysis.
evaluate_labels(df, x, y, lag, lead, covars)